import wandb
import numpy as np
from omegaconf import OmegaConf
from collections import Counter
from torch.utils.data import Subset
import json

def get_transfer_configs(wandb_run_id: str):
    """Return the full config of a W&B run as a flat ``{key: value}`` dict.

    The run's ``json_config`` string is parsed with ``json.loads``. Each entry in
    it is indexed with ``["value"]``, and only that field is kept.

    Args:
        wandb_run_id: Run path passed to ``wandb.Api().run``.

    Returns:
        Dict mapping each config key to the ``"value"`` field of its entry.
    """
    run = wandb.Api().run(wandb_run_id)
    raw_config = json.loads(run.json_config)
    return {key: entry["value"] for key, entry in raw_config.items()}

def _config_group(config, prefix: str) -> dict:
    """Collect the entries of a flat ``"a/b/c"``-keyed config whose first segment is ``prefix``.

    Each selected entry is keyed by its LAST path segment only. For example,
    ``"model/encoder/dim"`` becomes ``"dim"``. When two selected keys share a last
    segment, the later one (in ``config`` iteration order) wins.
    """
    return {
        key.split('/')[-1]: value
        for key, value in config.items()
        if key.split('/')[0] == prefix
    }


def get_configs(wandb_run_id: str):
    """Rebuild a run's federated-training configuration from W&B as an OmegaConf object.

    Args:
        wandb_run_id: Run path passed to ``wandb.Api().run``.

    Returns:
        An OmegaConf config with, in this order:
          - ``gen``: the ``gan_model/*`` run-config entries. Present only when the
            run config has a truthy ``trained_gen`` or ``train_gen``.
          - ``num_clients``: read from the run config.
          - ``clients``: ``client_{i}`` -> best val/test accuracy, best model path,
            train/val data indices and train label counts, read from the run
            summary (``splits_info`` holds the per-client split data).
          - ``model``: the ``model/*`` run-config entries.
          - ``old_datamodule``: the ``datamodule/*`` entries re-nested into a tree.
        ``gen`` and ``model`` are keyed by the last path segment of each key.
    """
    api = wandb.Api()
    run = api.run(wandb_run_id)
    summary = run.summary  # bound once; reused for every per-client lookup
    num_clients = run.config["num_clients"]

    clients = {}
    for i in range(num_clients):
        split = summary['splits_info'][f'client_{i}']
        clients[f"client_{i}"] = {
            "val_acc": summary[f"client-{i}_best-val-acc"],
            "test_acc": summary[f"client-{i}_best-test-acc"],
            "model_path": summary[f"client-{i}_best-model-path"],
            "train_data_indices": split['train_data_indices'],
            "val_data_indices": split['val_data_indices'],
            "train_data_distribution": dict(split['train_labels_count']),
        }

    configs = {}
    # The generator section only exists for runs that trained / used a GAN.
    if run.config.get("trained_gen") or run.config.get("train_gen"):
        configs["gen"] = _config_group(run.config, "gan_model")
    configs["num_clients"] = num_clients
    configs["clients"] = clients
    configs["model"] = _config_group(run.config, "model")
    configs["old_datamodule"] = expand_flatdict_to_tree(run.config, "datamodule")["datamodule"]
    return OmegaConf.create(configs)



def get_configs2(wandb_run_id: str):
    """Build a reduced per-client configuration of a W&B run as an OmegaConf object.

    For each client this records only three things:
      - the best validation accuracy;
      - the ``client-{i}_best-test-acc-round0`` summary metric, stored as ``test_acc``;
      - the best model path.
    No split information, ``gen`` section or ``old_datamodule`` is produced.

    Args:
        wandb_run_id: Run path passed to ``wandb.Api().run``.

    Returns:
        OmegaConf config with keys ``num_clients``, ``clients`` and ``model``.
        ``model`` holds the ``model/*`` run-config entries keyed by their last
        path segment.
    """
    run = wandb.Api().run(wandb_run_id)
    num_clients = run.config["num_clients"]

    def _client_entry(idx):
        # Summary metrics for one client, looked up by their logged names.
        return {
            "val_acc": run.summary[f"client-{idx}_best-val-acc"],
            "test_acc": run.summary[f"client-{idx}_best-test-acc-round0"],
            "model_path": run.summary[f"client-{idx}_best-model-path"],
        }

    clients = {f"client_{idx}": _client_entry(idx) for idx in range(num_clients)}
    model = {
        name.split('/')[-1]: value
        for name, value in run.config.items()
        if name.split('/')[0] == "model"
    }
    return OmegaConf.create(
        {"num_clients": num_clients, "clients": clients, "model": model}
    )



# def get_configs(wandb_run_id: str):
#     configs = {}
#     # clients = {}
#     api = wandb.Api()
#     run = api.run(wandb_run_id)
#     num_clients = run.config["num_clients"]
#     clients = {
#         f"client_{i}":
#             {
#                 "val_acc": run.summary[f"client-{i}_best-val-acc"],
#                 "test_acc": run.summary[f"client-{i}_best-test-acc"],
#                 "model_path": run.summary[f"client-{i}_best-model-path"],
#                 "train_data_indices": dict(run.summary['splits_info'][f"client_{i}"]['train_data_indices'])['value'],
#                 "val_data_indices": dict(run.summary['splits_info'][f"client_{i}"]['val_data_indices'])['value'],
#                 "train_data_distribution": dict(run.summary['splits_info'][f'client_{i}']['train_labels_count'])
#             }
#         for i in range(num_clients)
#     }
#     clients = dict(clients)
#     model = {}
#     for k, v in run.config.items():
#         key = k.split('/')
#         if key[0] == "model":
#             model[key[-1]] = v
#     if run.config.get("trained_gen") or run.config.get("train_gen"):
#         gen = {}
#         for k, v in run.config.items():
#             key = k.split('/')
#             if key[0] == "gan_model":
#                 gen[key[-1]] = v
#         configs["gen"] = gen
#
#     configs["num_clients"] = num_clients
#     configs["clients"] = clients
#     configs["model"] = model
#     configs["old_datamodule"] = expand_flatdict_to_tree(run.config, "datamodule")["datamodule"]
#     return OmegaConf.create(configs)


def expand_flatdict_to_tree(config, key_to_expand, sep='/'):
    """Re-nest the entries of a flat, path-keyed mapping that live under one root key.

    Keys of ``config`` are paths joined with ``sep`` (e.g. ``"datamodule/aug/flip"``).
    Only keys whose first segment equals ``key_to_expand`` are used; all others are
    ignored. Each selected value is stored at its nested position, and intermediate
    dicts are created as needed.

    Args:
        config: Mapping (anything with ``.items()``) whose keys are ``sep``-joined paths.
        key_to_expand: First path segment to select.
        sep: Separator used in the flat keys.

    Returns:
        A new dict of the form ``{key_to_expand: {...}}``, or ``{}`` when no key
        matches. A key equal to ``key_to_expand`` itself (no separator) is stored
        directly as ``tree[key_to_expand]``.
    """
    tree = {}
    for flat_key, val in config.items():
        parts = flat_key.split(sep)  # split once per key, not once per segment
        if parts[0] != key_to_expand:
            continue
        node = tree
        for part in parts[:-1]:
            # Reuse an existing branch or start a new one.
            if part not in node:
                node[part] = {}
            node = node[part]
        node[parts[-1]] = val
    return tree



from collections import Counter, defaultdict
import numpy as np
import torch

# def log_splits_info(datasets_train, datasets_val, fair_val=False):
#     splits_info = {}
#     for i, (ds_t, ds_v) in enumerate(zip(datasets_train, datasets_val)):
#         # Handling different dataset structures for training data
#         if isinstance(ds_t, Subset):
#             train_indices = ds_t.indices
#             train_labels = ds_t.dataset.targets[train_indices]
#         else:
#             train_indices = range(len(ds_t))
#             train_labels = ds_t.targets if hasattr(ds_t, 'targets') else [ds_t[i][1] for i in train_indices]
#
#         # Handling different dataset structures for validation data
#         if isinstance(ds_v, Subset):
#             val_indices = ds_v.indices
#             val_labels = ds_v.dataset.targets[val_indices]
#         else:
#             val_indices = range(len(ds_v))
#             val_labels = ds_v.targets if hasattr(ds_v, 'targets') else [ds_v[i][1] for i in val_indices]
#
#         # Convert labels to numpy for consistent processing
#         train_labels = np.array(train_labels)
#         val_labels = np.array(val_labels)
#
#         # Prepare split info
#         splits_info[f'client_{i}'] = {
#             'train_labels_count': dict(Counter(train_labels)),
#             'val_labels_count': dict(Counter(val_labels)),
#             'train_data_indices': list(train_indices),
#             'val_data_indices': list(val_indices)
#         }
#
#         # Distribution information
#         unique_train, counts_train = np.unique(train_labels, return_counts=True)
#         unique_val, counts_val = np.unique(val_labels, return_counts=True)
#         splits_info[f'client_{i}']['train_labels_distribution'] = dict(zip(unique_train, counts_train / len(train_indices) * 100))
#         splits_info[f'client_{i}']['val_labels_distribution'] = dict(zip(unique_val, counts_val / len(val_indices) * 100))
#         # print(f"splits_info: {splits_info}")
#
#     # Ensuring no intersection in data splits across clients
#     train_indices_sets = [set(info['train_data_indices']) for info in splits_info.values()]
#     val_indices_sets = [set(info['val_data_indices']) for info in splits_info.values()]
#
#     # Check intersection across clients for training and validation sets separately
#     for idx, train_set in enumerate(train_indices_sets):
#         for jdx, other_train_set in enumerate(train_indices_sets):
#             if idx != jdx:
#                 assert train_set.isdisjoint(other_train_set), f"Train data indices for client {idx} and client {jdx} overlap."
#
#     # for idx, val_set in enumerate(val_indices_sets):
#     #     for jdx, other_val_set in enumerate(val_indices_sets):
#     #         if idx != jdx:
#     #             assert val_set.isdisjoint(other_val_set), f"Validation data indices for client {idx} and client {jdx} overlap."
#
#     if not fair_val:
#         # Additionally check train-validation intersection for each client if fair validation is not used
#         for idx, train_set in enumerate(train_indices_sets):
#             for jdx, val_set in enumerate(val_indices_sets):
#                 assert train_set.isdisjoint(val_set), f"Train and validation data indices for client {idx} overlap."
#
#     return splits_info



def _label_stats(labels, indices):
    """Return ``(counts, distribution)`` for the labels selected by ``indices``.

    Args:
        labels: Indexable with a list of positions, returning something with
            ``.numpy()``. Presumably a CPU ``torch.Tensor`` holding the whole
            dataset's labels (TODO confirm).
        indices: List of positions into ``labels``.

    Returns:
        ``counts`` maps ``str(label)`` to its number of occurrences, in first-seen
        order. ``distribution`` maps ``str(label)`` to its percentage of
        ``indices``, sorted by label.
    """
    selected = labels[indices].numpy()
    counts = dict(Counter(map(str, selected)))
    uniques, freqs = np.unique(selected, return_counts=True)
    distribution = dict(zip(map(str, uniques), freqs / len(indices) * 100))
    return counts, distribution


def _check_splits_disjoint(splits_info, fair_val):
    """Raise ``AssertionError`` if data indices are shared where splits must be disjoint.

    Checks, in order:
      - training indices of different clients;
      - validation indices of different clients (skipped when ``fair_val``);
      - every client's training indices against every client's validation
        indices, including its own.

    The checks raise explicitly instead of using ``assert``, so they still run
    under ``python -O``.
    """
    from itertools import combinations, product

    clients = list(splits_info)
    # Build each index set once instead of once per compared pair.
    train_sets = {c: set(splits_info[c]["train_data_indices"]) for c in clients}
    val_sets = {c: set(splits_info[c]["val_data_indices"]) for c in clients}

    for first, second in combinations(clients, 2):
        if not train_sets[first].isdisjoint(train_sets[second]):
            raise AssertionError(f"Train data indices of {first} and {second} overlap.")
    if not fair_val:
        for first, second in combinations(clients, 2):
            if not val_sets[first].isdisjoint(val_sets[second]):
                raise AssertionError(f"Validation data indices of {first} and {second} overlap.")
    for first, second in product(clients, repeat=2):
        if not train_sets[first].isdisjoint(val_sets[second]):
            raise AssertionError(
                f"Train data indices of {first} overlap validation data indices of {second}."
            )


def log_splits_info(datasets_train: "list[Subset]", datasets_val: "list[Subset]", fair_val: bool = False) -> dict:
    """Summarise the label composition of each client's train/val split and check disjointness.

    Client ``i`` is ``(datasets_train[i], datasets_val[i])``. If the two sequences
    differ in length, the extra entries are ignored (``zip`` semantics). Labels
    for BOTH splits are read from the train subset's ``targets``, indexed with
    ``.indices``. This assumes both subsets of a client index the same
    underlying label array (TODO confirm).

    The input subsets are not modified.

    Args:
        datasets_train: Per-client training subsets. Each needs ``indices`` and
            ``targets``.
        datasets_val: Per-client validation subsets. Each needs ``indices``.
        fair_val: If true, validation indices may be shared between clients.

    Returns:
        ``{"client_{i}": info}``. Each ``info`` has these keys, in order:
          - ``labels_count`` / ``labels_distribution``: over train + val;
          - ``train_labels_count`` / ``train_labels_distribution``;
          - ``train_data_indices``;
          - ``val_labels_count`` / ``val_labels_distribution``;
          - ``val_data_indices``.
        Counts and distributions are keyed by ``str(label)``; distributions are
        in percent.

    Raises:
        AssertionError: if any train/val index sets overlap in a disallowed way.
    """
    splits_info = {}
    for i, (ds_t, ds_v) in enumerate(zip(datasets_train, datasets_val)):
        # Local list copies: ``+`` below must concatenate, whatever the original
        # index container type, and the caller's subsets stay untouched.
        train_indices = list(ds_t.indices)
        val_indices = list(ds_v.indices)
        labels = ds_t.targets

        all_counts, all_dist = _label_stats(labels, train_indices + val_indices)
        train_counts, train_dist = _label_stats(labels, train_indices)
        val_counts, val_dist = _label_stats(labels, val_indices)

        splits_info[f"client_{i}"] = {
            "labels_count": all_counts,
            "labels_distribution": all_dist,
            "train_labels_count": train_counts,
            "train_labels_distribution": train_dist,
            "train_data_indices": train_indices,
            "val_labels_count": val_counts,
            "val_labels_distribution": val_dist,
            "val_data_indices": val_indices,
        }

    _check_splits_disjoint(splits_info, fair_val)
    return splits_info
